# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch


class Matcher(object):  # Matcher,用于划分样本
    """
    This class assigns to each predicted "element" (e.g., a box) a ground-truth
    element. Each predicted element will have exactly zero or one matches; each
    ground-truth element may be assigned to zero or more predicted elements.
    这个类为每个预测的“元素”（例如，一个框）分配一个基本的真值元素。
    每一个预测元素都有精确的零个或一个匹配；
    每一个基本真值元素可以被分配给零个或多个预测元素。

    Matching is based on the MxN match_quality_matrix, that characterizes how well
    each (ground-truth, predicted)-pair match. For example, if the elements are
    boxes, the matrix may contain box IoU overlap values.
    匹配基于MxN match_quality_矩阵，
    该矩阵描述了每一对（GT，预测）匹配的程度。
    例如，如果元素是框，则矩阵可以包含框IoU重叠值。

    The matcher returns a tensor of size N containing the index of the ground-truth
    element m that matches to prediction n. If there is no match, a negative value
    is returned.
    匹配器返回一个大小为N的张量，
    其中包含与预测N匹配的基本真值元素m的索引。
    如果没有匹配，则返回一个负值。

    这个类主要实现将RPN提取出来的所有锚点(anchor)与标注的基准边框(ground truth box)进行匹配
    每一个锚点都会匹配一个与之对应的基准边框，当锚点与基准边框的Iou小于一定值时，认定其找不到对应
    的边框，认定其为背景。每一个基准边框对应０个或者多个锚点
    这个匹配操作是基于计算过的各个锚点与基准边框之间的IoU的MxN矩阵(match_quality_matrix)来
    进行的。其中M为基准边框的个数，N为锚点的个数。IoU矩阵的每一列表示某个锚点与所有各个基准框之间
    的IoU,每一行表示每个基准边框与所有各个锚点之间的IoU
    Matcher类返回一个长度为N的向量，其表示每一个锚点的类型：背景-1,介于背景和目标之间-2以及
    目标边框（*各自对应的基准边框的索引*）
    """

    BELOW_LOW_THRESHOLD = -1
    BETWEEN_THRESHOLDS = -2

    def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
        """
        Args:
            high_threshold (float): quality values greater than or equal to
                this value are candidate matches.
            low_threshold (float): a lower quality threshold used to stratify
                matches into three levels:
                1) matches >= high_threshold
                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
            allow_low_quality_matches (bool): if True, produce additional matches
                for predictions that have only low-quality match candidates. See
                set_low_quality_matches_ for more details.
        """
        assert low_threshold <= high_threshold
        self.high_threshold = high_threshold
        self.low_threshold = low_threshold
        self.allow_low_quality_matches = allow_low_quality_matches  # matcher划分样本

    def __call__(self, match_quality_matrix):
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
            pairwise quality between M ground-truth elements and N predicted elements.
            match_quality_matrix（Tensor[float]）：MxN张量，
            包含M个GT和N个预测元素之间的成对质量。

        Returns:
            matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
            [0, M - 1] or a negative value indicating that prediction i could not
            be matched.
            匹配（张量[int64]）：N张量，其中N[i]是[0，M-1]中的匹配gt，或者是表示无法匹配预测i的负值。

            match_quality_matrix，是一个MxN的矩阵，其中M为基准边框的个数，N为锚点的个数。
            IoU矩阵的每一列表示某个锚点与所有各个基准框之间的IoU,
            每一行表示每个基准边框与所有各个锚点之间的IoU
        """
        if match_quality_matrix.numel() == 0:
            # empty targets or proposals not supported during training
            if match_quality_matrix.shape[0] == 0:
                raise ValueError(
                    "No ground-truth boxes available for one of the images "
                    "during training")
            else:
                raise ValueError(
                    "No proposal boxes available for one of the images "
                    "during training")

        # match_quality_matrix is M (gt) x N (predicted)
        # Max over gt elements (dim 0) to find best gt candidate for each prediction
        # 从每一列中找到最大的IoU，即找到与锚点IoU最大的基准边框。得到Iou的值以及基准边框的索引.
        # 我这个锚框和那个GT最接近
        matched_vals, matches = match_quality_matrix.max(dim=0)
        if self.allow_low_quality_matches:
            all_matches = matches.clone()

        # Assign candidate matches with low quality to negative (unassigned) values
        below_low_threshold = matched_vals < self.low_threshold
        between_thresholds = (matched_vals >= self.low_threshold) & (
            matched_vals < self.high_threshold
        )
        # 这个用法很有趣：
        # y=torch.tensor([0,2,0,1])
        # c = x < 1     # tensor([ True, False, False, False])
        # y[c] = -1     # tensor([-1,  2,  0,  1])

        matches[below_low_threshold] = Matcher.BELOW_LOW_THRESHOLD  # -1
        matches[between_thresholds] = Matcher.BETWEEN_THRESHOLDS  # -2

        if self.allow_low_quality_matches:  # 我这个GT和哪个锚点框最接近
            self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

        return matches

    def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
        """
        Produce additional matches for predictions that have only low-quality matches.
        Specifically, for each ground-truth find the set of predictions that have
        maximum overlap with it (including ties); for each prediction in that set, if
        it is unmatched, then match it to the ground-truth with which it has the highest
        quality value.
        为只有低质量匹配的预测生成额外的匹配。
        具体来说，对于每个GT，
        找到与之最大重叠的预测集（包括关联）；
        对于该集中的每个预测，如果不匹配，则将其与具有最高质量值的GT匹配。
        """
        # For each gt, find the prediction with which it has highest quality
        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
        # Find highest quality match available, even if it is low, including ties
        gt_pred_pairs_of_highest_quality = torch.nonzero(
            match_quality_matrix == highest_quality_foreach_gt[:, None]
        )
        # Example gt_pred_pairs_of_highest_quality:
        #   tensor([[    0, 39796],
        #           [    1, 32055],
        #           [    1, 32070],
        #           [    2, 39190],
        #           [    2, 40255],
        #           [    3, 40390],
        #           [    3, 41455],
        #           [    4, 45470],
        #           [    5, 45325],
        #           [    5, 46390]])
        # Each row is a (gt index, prediction index)
        # Note how gt items 1, 2, 3, and 5 each have two ties

        pred_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1]
        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
